import os
import fnmatch

from pathlib import Path


def get_bkgnd_img_path(mask_path, img_basedir, bkgnd_roi_channel, bkgnd_img_ext='.png'):
    """
    Retrieve path to background image for given mask

    The background image filename is derived from the mask filename by dropping the extension, removing '_mask',
    and swapping the last '_'-delimited token (taken to be the channel name) for `bkgnd_roi_channel`.

    Parameters
    ----------
    mask_path: str
        Path to mask file
    img_basedir: str
        Directory to search (recursively) for background image
    bkgnd_roi_channel: str
        Channel to use for background (e.g., 'AF488')
    bkgnd_img_ext: str
        Extension of background image file (default='.png')

    Returns
    -------
    bkgnd_img_path: str
       Path to background image for given mask

    Raises
    ------
    ValueError
        If the number of matching files found under `img_basedir` is not exactly one.
    """

    # Get background image filename
    mask_fname = os.path.splitext(os.path.basename(mask_path))[0]
    img_fname = mask_fname.replace('_mask', '')

    # Swap only the trailing token. A plain str.replace() on the whole name would also rewrite any earlier part
    # of the filename that happens to contain the channel text (e.g. 'DAPI1_S1_DAPI').
    prefix, delim, _ = img_fname.rpartition('_')
    bkgnd_img_fname = prefix + delim + bkgnd_roi_channel + bkgnd_img_ext

    # Find path of background image
    bkgnd_img_path = [str(item) for item in Path(img_basedir).rglob(bkgnd_img_fname)]

    if len(bkgnd_img_path) != 1:
        raise ValueError('Found %d background channel image for %s'
                         % (len(bkgnd_img_path), mask_fname))

    return bkgnd_img_path[0]


def get_orig_img_path(mask_path, img_basedir, img_ext='.png', mask_key='_mask'):
    """
    Locate the image file that a mask file was derived from.

    The image filename is built from the mask filename by dropping its extension, removing every occurrence of
    `mask_key`, and appending `img_ext`. `img_basedir` is then searched recursively for that filename.

    Parameters
    ----------
    mask_path: str
        Filename of the mask, optionally preceded by a directory.
    img_basedir: str
        Base directory that is searched recursively for the image file.
    img_ext: str
        Image extension. Default='.png'
    mask_key: str
        Substring of the mask filename that marks it as a mask file. Default='_mask'.

    Returns
    -------
    orig_img_path: str
        Path to the single matching image file.

    Raises
    ------
    ValueError
        If zero or more than one matching image file is found.
    """

    # Derive the image filename from the mask filename
    mask_stem, _ = os.path.splitext(os.path.basename(mask_path))
    target_name = mask_stem.replace(mask_key, '') + img_ext

    # Collect every file below the base directory that carries that name
    matches = [str(hit) for hit in Path(img_basedir).rglob(target_name)]

    if len(matches) == 1:
        return matches[0]

    raise ValueError('Found %d original image for %s' % (len(matches), mask_stem))


def get_roi_excl_path(mask_path, excl_roi_dir):
    """
    Look up the ROI exclusion file that belongs to a mask.

    The first three '_'-delimited tokens of the mask filename are joined back together and used as a filename
    prefix; `excl_roi_dir` is searched recursively for any entry whose name starts with that prefix.

    Parameters
    ----------
    mask_path: str
        Path to mask file
    excl_roi_dir: str
        Directory to search for ImageJ ROI file. Extracted ROI will be excluded from background ROI

    Returns
    -------
    bkgnd_excl_roi_path: str or list
        Path to the ROI exclude file when exactly one match exists; an empty list when there is no match.

    Raises
    ------
    ValueError
        If more than one match is found.
    """

    mask_stem = os.path.splitext(os.path.basename(mask_path))[0]

    # The first 3 string items contain mouse#, slide#, and brain region info
    id_prefix = '_'.join(mask_stem.split('_')[:3])

    hits = [str(entry) for entry in Path(excl_roi_dir).rglob(id_prefix + '*')]

    if len(hits) > 1:
        raise ValueError('Found %d .roi file for ROI exclusion for %s. Should be 0 or 1.'
                         % (len(hits), mask_stem))

    if hits:
        return hits[0]

    return hits  # nothing found: hand back the empty list


def confirm_mask_paths_get_sample_id(mask_basepath, channels, mask_key='_mask'):
    """
    Search through input directory 'mask_basepath' to confirm that for all specified channels, the corresponding mask
    file can be found. Raises an error otherwise.

    Sample IDs are derived from each mask filename by removing `mask_key` and the channel name, then collapsing the
    doubled delimiters ('_ ', '__', '  ') left behind. A mask counts as "found" for a sample/channel when its
    filename contains the channel name and every '_'/whitespace-delimited token of the sample ID as a substring.

    Parameters
    ----------
    mask_basepath: str
        Directory where mask files (.npy) are located; searched recursively for '*mask.npy'
    channels: list of str
        List of channel names that are in mask filenames
    mask_key: str
        A string that signifies that the file is a mask file to pattern match against. Default = '_mask'

    Returns
    -------
    mask_paths, sample_id:
        Paths to individual mask files, and a list of unique sample IDs

    Raises
    ------
    ValueError
        If `mask_basepath` does not exist, if it contains no '*mask.npy' file, or if a sample ID has no mask file
        for one of the channels.
    """

    mask_filename_pattern = '*mask.npy'  # pattern match for file name containing mask

    mask_basepath = Path(mask_basepath)
    if not mask_basepath.exists():
        raise ValueError("Failed to find mask directory, check directory spelling.")

    # Retrieve paths of all mask files found in mask directory
    mask_paths = [str(p) for p in mask_basepath.rglob(mask_filename_pattern)]
    if not mask_paths:
        raise ValueError("Failed to find any '*mask.npy' files in mask directory.")

    # File names without directory and extension, in the same order as mask_paths
    mask_fnames = [os.path.splitext(os.path.basename(p))[0] for p in mask_paths]

    # Generate the raw sample IDs that have masks across different colour channels
    raw_sample_id = set()
    for fname in mask_fnames:
        # Only names that begin with mask_key are skipped; names that lack mask_key altogether are still processed.
        # NOTE(review): a "contains mask_key" check may have been the intent - confirm before tightening.
        if fname.startswith(mask_key):
            continue

        fname = fname.replace(mask_key, '')  # remove the mask key characters from file name
        for ch in channels:
            if ch in fname:
                raw_sample_id.add(fname.replace(ch, ''))  # remove channel characters from file name

    # Tidy up the sample ID names. De-duplicate again afterwards: distinct raw IDs (e.g. 'a_' and 'a__') can
    # collapse to the same tidy ID, and the returned list must hold each sample once.
    sample_id = []
    for raw in sorted(raw_sample_id):
        tidy = raw.replace('_ ', '_').replace('__', '_').replace('  ', ' ')
        if tidy not in sample_id:
            sample_id.append(tidy)

    # For each sample ID (Subject ID + slide ID), check that mask files for each channel can be found in mask directory
    for sample_i in sample_id:
        # Keywords of the sample ID, obtained by splitting on '_' and whitespace (empty pieces are dropped)
        sample_i_keyword = sample_i.replace('_', ' ').split()

        for ch in channels:
            required = sample_i_keyword + [ch]
            found_mask_ch = any(all(key in fname for key in required) for fname in mask_fnames)

            if not found_mask_ch:
                raise ValueError('Mask file not found for ' + '"' + sample_i + '" for channel "' + ch + '"')

    print("Mask file found for all channels. OK to proceed.")

    return mask_paths, sample_id


def get_mask_paths_sample_id_depreciated(mask_basepath, channels):
    """
    **DEPRECIATED. Use 'confirm_mask_paths_get_sample_id' instead**
    Collect all '*mask.npy' files below a directory and verify that every sample has a mask for every channel.

    Mask filenames are taken to follow '<sample>_<channel>_mask.npy': the second-to-last '_'-delimited token of
    the filename (without extension) is treated as the channel, and the rest as the sample ID.

    Parameters
    ----------
    mask_basepath: str
        Directory where mask files (.npy) are located; searched recursively
    channels: list of str
        List of channel names that are in mask filenames

    Returns
    -------
    mask_paths, samples_id:
        Paths to individual mask files, and a sorted list of unique sample IDs

    Raises
    ------
    ValueError
        If the directory does not exist, holds no '*mask.npy' file, or lacks a mask for some sample/channel pair.
    """

    mask_root = Path(mask_basepath)
    if not mask_root.exists():
        raise ValueError("Failed to find mask directory, check directory spelling.")

    # Every mask file below the mask directory
    mask_paths = [str(hit) for hit in mask_root.rglob('*mask.npy')]
    if not mask_paths:
        raise ValueError("Failed to find any '*mask.npy' files in mask directory.")

    # Unique sample IDs: file stem with the '_<channel>_mask' part removed
    unique_ids = set()
    for path_str in mask_paths:
        stem = os.path.splitext(os.path.basename(path_str))[0]
        ch_token = stem.split('_')[-2]  # channel sits just before the trailing 'mask' token
        unique_ids.add(stem.replace('_' + ch_token + '_mask', ''))
    samples_id = sorted(unique_ids)

    # Each sample must have a mask file for each requested channel
    for sample_name in samples_id:
        for ch in channels:
            expected_fname = sample_name + '_' + ch + '_mask.npy'
            if not any(expected_fname in path_str for path_str in mask_paths):
                raise ValueError("Mask file not found for " + expected_fname)
    print("Mask file found for all channels. OK to proceed.")

    return mask_paths, samples_id


def img_raw_filename_change(img_basedir, img_ext, name_id_search, slide_nums, channels, multi_slide_wildcard=None):
    """
    Changes image filename from raw filename from confocal to minimalist filename for downstream
    processing. New filename contains only name ID, slide number, and channel name
    ('<name ID>_<slide>_<channel><img_ext>'); files are renamed in place, in their own directory.

    Parameters
    ----------
    img_basedir: str
        Base directory containing images; searched recursively for '*' + img_ext
    img_ext: str
        Image extension, e.g., '.png' or '.tif'
    name_id_search: str
        Pattern matching unique name ID, e.g., subject ID. The first '-', '_' or ' ' delimited token of the
        filename that contains this string is used as the name ID.
    slide_nums: list of str
        List of slide numbers for pattern matching, e.g., ['S1', 'S2', 'S3']. Case sensitive
    channels: list of str
        List of channel for pattern matching, e.g., ['DAPI', 'AF488', 'AF555', 'AF680']
    multi_slide_wildcard: str or None
        String to match to filename for multi-slides/scenes, using wildcard matching. Case sensitive.
        When None or empty, multi-slide handling is skipped and only single-slide matching is done.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a filename lacks the name ID, a slide number or a channel name, or if the scene number of a
        multi-scene filename cannot be parsed.
    FileExistsError
        If the new filename is already taken by a different file (prevents silently overwriting an image).
    """

    # Snapshot the matching files up front so that renaming entries cannot interfere with the directory walk
    files = list(Path(img_basedir).rglob('*' + img_ext))

    for file in files:

        fname = os.path.splitext(file.name)[0]
        fname_split = fname.replace('-', ' ').replace('_', ' ').split(' ')
        # For slide detection '_' is kept inside tokens, so that 'stitch_s?_*' style tokens stay intact
        fname_split_slidenum = fname.replace('-', ' ').split(' ')

        name_id = [i for i in fname_split if name_id_search in i]
        if not name_id:
            raise ValueError("'" + name_id_search + "' was not found in filename '" + fname + "'")
        name_id = name_id[0]  # use 1st instance of name_id, if multiple instances exist

        # 'none' is the sentinel for "not detected yet"
        curr_slide = 'none'
        curr_channel = 'none'

        # process slide number (when several tokens qualify, the last match wins)
        for slide_num in slide_nums:
            for i in fname_split_slidenum:

                # process multi-slide case first (with multi scenes)
                if multi_slide_wildcard and fnmatch.fnmatchcase(i, multi_slide_wildcard):
                    slide_name = i

                    # process left vs right (L vs R) slide; slicing keeps an empty token from raising
                    slide_side = slide_name[-1] if slide_name[-1:] in ('L', 'R') else ''

                    # process multi-slide numbering: the n-th digit of the token is the slide of scene n
                    multi_slide_num = [char for char in slide_name if char.isdigit()]
                    for j in fname_split_slidenum:
                        # 'stitch_s?_*' indexes if image is from scene 1, 2 etc
                        if fnmatch.fnmatchcase(j, 'stitch_s?_*'):
                            scene_num = j.replace('_', ' ').split(' ')[1]  # select s1, s2, s3... etc
                            scene_num = scene_num[1:]
                            if not scene_num.isdigit():
                                raise ValueError("Error: Cannot process multi-scene file name for file '" + fname + "'")
                            curr_slide = 'S' + multi_slide_num[int(scene_num) - 1] + slide_side

                # next process single slide case (with single scene)
                else:
                    for j in i.replace('_', ' ').split(' '):
                        if slide_num in j:
                            curr_slide = j

        # process colour channel name (when several channels qualify, the last one in `channels` wins)
        for channel in channels:
            for i in fname_split:
                if channel in i:
                    curr_channel = channel

        if curr_slide == 'none':
            raise ValueError("Error: Filename does not contain slide number ID for file '" + fname + "'")
        elif curr_channel == 'none':
            raise ValueError("Error: Filename does not contain channel name ID for file '" + fname + "'")

        new_name = name_id + "_" + curr_slide + "_" + curr_channel + img_ext
        new_path = file.parent / new_name

        # On POSIX, rename() silently replaces an existing target; refuse rather than destroy another image
        if new_path != file and new_path.exists():
            raise FileExistsError("Cannot rename '" + str(file) + "': '" + str(new_path) + "' already exists")

        print("Renaming to: " + str(new_path))
        file.rename(new_path)


def find_mask_path_w_sample_id(mask_paths, sample_i, channel, mask_keyword='mask'):
    """
    Pick, from a collection of mask file paths, the one that belongs to a given sample ID and channel.

    A path qualifies when its filename (directory and extension removed) contains, as substrings, every
    '_'/whitespace-delimited token of `sample_i`, the channel name and the mask keyword. When several paths
    qualify, the last one in `mask_paths` is returned.

    Parameters
    ----------
    mask_paths: iterable of str
        Paths of the candidate mask files (.npy).
    sample_i: str
        Sample ID (Subject ID + slide ID), for which the mask file is to be searched.
    channel: str
        Channel name, for which the mask file is to be searched.
    mask_keyword: str
        The keyword in the filename that signifies that it is a mask file. Default='mask'.

    Returns
    -------
    my_mask_path: str
        The path of the mask file corresponding to the Sample ID and Channel.

    Raises
    ------
    ValueError
        If no path qualifies.
    """

    # Tokens every qualifying filename must contain: the sample ID pieces (split on '_' and whitespace,
    # empty pieces dropped), then the channel name, then the mask keyword
    required_tokens = sample_i.replace('_', ' ').split()
    required_tokens += [channel, mask_keyword]

    candidates = [
        candidate
        for candidate in mask_paths
        if all(token in os.path.splitext(os.path.basename(candidate))[0] for token in required_tokens)
    ]

    if not candidates:
        raise ValueError('Mask file not found for ' + '"' + sample_i + '" for channel "' + channel + '"')

    return candidates[-1]  # the last qualifying path takes precedence
